{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Mnist分类任务：\n",
    "\n",
    "- 网络基本构建与训练方法，常用函数解析\n",
    "\n",
    "- torch.nn.functional模块\n",
    "\n",
    "- nn.Module模块\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 读取Mnist数据集\n",
    "- 会自动进行下载"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "from pathlib import Path\n",
    "import requests\n",
    "\n",
    "DATA_PATH = Path(\"data\")\n",
    "PATH = DATA_PATH / \"mnist\"\n",
    "\n",
    "PATH.mkdir(parents=True, exist_ok=True)\n",
    "\n",
    "URL = \"http://deeplearning.net/data/mnist/\"\n",
    "FILENAME = \"mnist.pkl.gz\"\n",
    "\n",
    "if not (PATH / FILENAME).exists():\n",
    "        content = requests.get(URL + FILENAME).content\n",
    "        (PATH / FILENAME).open(\"wb\").write(content)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import pickle\n",
    "import gzip\n",
    "\n",
    "with gzip.open((PATH / FILENAME).as_posix(), \"rb\") as f:\n",
    "        ((x_train, y_train), (x_valid, y_valid), _) = pickle.load(f, encoding=\"latin-1\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "784是mnist数据集每个样本的像素点个数"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(50000, 784)\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAP8AAAD8CAYAAAC4nHJkAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBo\ndHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAADgpJREFUeJzt3X+MVfWZx/HPs1j+kKI4aQRCYSnE\nYJW4082IjSWrxkzVDQZHrekkJjQapn8wiU02ZA3/VNNgyCrslmiamaZYSFpKE3VB0iw0otLGZuKI\nWC0srTFsO3IDNTjywx9kmGf/mEMzxbnfe+fec++5zPN+JeT+eM6558kNnznn3O+592vuLgDx/EPR\nDQAoBuEHgiL8QFCEHwiK8ANBEX4gKMIPBEX4gaAIPxDUZc3cmJlxOSHQYO5u1SxX157fzO40syNm\n9q6ZPVrPawFoLqv12n4zmybpj5I6JQ1Jel1St7sfSqzDnh9osGbs+ZdJetfd33P3c5J+IWllHa8H\noInqCf88SX8Z93goe+7vmFmPmQ2a2WAd2wKQs3o+8Jvo0OJzh/Xu3i+pX+KwH2gl9ez5hyTNH/f4\ny5KO1dcOgGapJ/yvS7rGzL5iZtMlfVvSrnzaAtBoNR/2u/uImfVK2iNpmqQt7v6H3DoD0FA1D/XV\ntDHO+YGGa8pFPgAuXYQfCIrwA0ERfiAowg8ERfiBoAg/EBThB4Ii/EBQhB8IivADQRF+ICjCDwRF\n+IGgCD8QFOEHgiL8QFCEHwiK8ANBEX4gKMIPBEX4gaAIPxAU4QeCIvxAUIQfCIrwA0ERfiAowg8E\nVfMU3ZJkZkclnZZ0XtKIu3fk0RTyM23atGT9yiuvbOj2e3t7y9Yuv/zy5LpLlixJ1tesWZOsP/XU\nU2Vr3d3dyXU//fTTZH3Dhg3J+uOPP56st4K6wp+5zd0/yOF1ADQRh/1AUPWG3yXtNbM3zKwnj4YA\nNEe9h/3fcPdjZna1pF+b2f+6+/7xC2R/FPjDALSYuvb87n4suz0h6QVJyyZYpt/dO/gwEGgtNYff\nzGaY2cwL9yV9U9I7eTUGoLHqOeyfLekFM7vwOj939//JpSsADVdz+N39PUn/lGMvU9aCBQuS9enT\npyfrN998c7K+fPnysrVZs2Yl173vvvuS9SINDQ0l65s3b07Wu7q6ytZOnz6dXPett95K1l999dVk\n/VLAUB8QFOEHgiL8QFCEHwiK8ANBEX4gKHP35m3MrHkba6L29vZkfd++fcl6o79W26pGR0eT9Yce\neihZP3PmTM3bLpVKyfqHH36YrB85cqTmbTeau1s1y7HnB4Ii/EBQhB8IivADQRF+ICjCDwRF+IGg\nGOfPQVtbW7I+MDCQrC9atCjPdnJVqffh4eFk/bbbbitbO3fuXHLdqNc/1ItxfgBJhB8IivADQRF+\nICjCDwRF+IGgCD8QVB6z9IZ38uTJZH3t2rXJ+ooVK5L1N998M1mv9BPWKQcPHkzWOzs7k/WzZ88m\n69dff33Z2iOPPJJcF43Fnh8IivADQRF+ICjCDwRF+IGgCD8QFOEHgqr4fX4z2yJphaQT7r40e65N\n0g5JCyUdlfSAu6d/6FxT9/v89briiiuS9UrTSff19ZWtPfzww8l1H3zwwWR9+/btyTpaT57f5/+p\npDsveu5RSS+5+zWSXsoeA7iEVAy/u++XdPElbCslbc3ub5V0T859AWiwWs/5Z7t7SZKy26vzawlA\nMzT82n4z65HU0+jtAJicWvf8x81sriRltyfKLeju/e7e4e4dNW4LQAPUGv5dklZl91dJ2plPOwCa\npWL4zWy7pN9JWmJmQ2b2sKQNkjrN7E+SOrPHAC4hFc/53b27TOn2nHsJ69SpU3Wt/9FHH9W87urV\nq5P1HTt2JOujo6M1bxvF4go/ICjCDwRF+IGgCD8QFOEHgiL8QFBM0T0FzJgxo2ztxRdfTK57yy23\nJOt33XVXsr53795kHc3HFN0Akgg/EBThB4Ii/EBQhB8IivADQRF+ICjG+ae4xYsXJ+sHDhxI1oeH\nh5P1l19+OVkfHBwsW3vmmWeS6zbz/+ZUwjg/gCTCDwRF+IGgCD8QFOEHgiL8QFCEHwiKcf7gurq6\nkvVnn302WZ85c2bN2163bl2yvm3btmS9VCrVvO2pjHF+AEmEHwiK8ANBEX4gKMIPBEX4gaAIPxBU\nxXF+M9siaYWkE+6+NHvuMUmrJf01W2ydu/+q4sYY57/kLF26NFnftGlTsn777bXP5N7X15esr1+/\nPll///33a972pSzPcf6fSrpzguf/093bs38Vgw+gtVQMv7vvl3SyCb0AaKJ6zvl7zez3ZrbFzK7K\nrSMATVFr+H8kabGkdkklSRvLLWhmPWY2aGblf8wNQNPVFH53P+7u5919VNKPJS1LLNvv7h3u3lFr\nkwDyV1P4zWzuuIddkt7Jpx0AzXJZpQXMbLukWyV9ycyGJH1f0q1m1i7JJR2V9N0G9gigAfg+P+oy\na9asZP3uu+8uW6v0WwFm6eHqffv2JeudnZ3J+lTF9/kBJBF+ICjCDwRF+IGgCD8QFOEHgmKoD4X5\n7LPPkvXLLktfhjIyMpKs33HHHWVrr7zySnLdSxlDfQCSCD8QFOEHgiL8QFCEHwiK8ANBEX4gqIrf\n50dsN9xwQ7J+//33J+s33nhj2VqlcfxKDh06lKzv37+/rtef6tjzA0ERfiAowg8ERfiBoAg/EBTh\nB4Ii/EBQjPNPcUuWLEnWe3t7k/V77703WZ8zZ86ke6rW+fPnk/VSqZSsj46O5tnOlMOeHwiK8ANB\nEX4gKMIPBEX4gaAIPxAU4QeCqjjOb2bzJW2TNEfSqKR+d/+hmbVJ2iFpoaSjkh5w9w8b12pclcbS\nu7u7y9YqjeMvXLiwlpZyMTg4mKyvX78+Wd+1a1ee7YRTzZ5/RNK/uftXJX1d0hozu07So5Jecvdr\nJL2UPQZwiagYfncvufuB7P5pSYclzZO0UtLWbLGtku5pVJMA8jepc34zWyjpa5IGJM1295I09gdC\n0tV5Nwegcaq+tt/MvijpOUnfc/dTZlVNByYz65HUU1t7ABqlqj2/mX1BY8H/mbs/nz193MzmZvW5\nkk5MtK6797t7h7t35NEwgHxUDL+N7eJ/Iumwu28aV9olaVV2f5Wknfm3B6BRKk7RbWbLJf1G0tsa\nG+qTpHUaO+//paQFkv4s6VvufrLCa4Wconv27NnJ+nXXXZesP/3008n6tddeO+me8jIwMJCsP/nk\nk2VrO3em9xd8Jbc21U7RXfGc391/K6nci90+maYAtA6u8AOCIvxAUIQfCIrwA0ERfiAowg8ExU93\nV6mtra1sra+vL7lue3t7sr5o0aKaesrDa6+9lqxv3LgxWd+zZ0+y/sknn0y6JzQHe34gKMIPBEX4\ngaAIPxAU4QeCIvxAUIQfCCrMOP9NN92UrK9duzZZX7ZsWdnavHnzauopLx9//HHZ2ubNm5PrPvHE\nE8n62bNna+oJrY89PxAU4QeCIvxAUIQfCIrwA0ERfiAowg8EFWacv6urq656PQ4dOpSs7969O1kf\nGRlJ1lPfuR8eHk6ui7jY8wNBEX4gKMIPBEX4gaAIPxAU4QeCIvxAUObu6QXM5kvaJmmOpFFJ/e7+\nQzN7TNJqSX/NFl3n7r+q8FrpjQGom7tbNctVE/65kua6+wEzmynpDUn3SHpA0hl3f6rapgg/0HjV\nhr/iFX7uXpJUyu6fNrPDkor96RoAdZvUOb+ZLZT0NUkD2VO9ZvZ7M9tiZleVWafHzAbNbLCuTgHk\nquJh/98WNPuipFclrXf3581stqQPJLmkH2js1OChCq/BYT/QYLmd80uSmX1B0m5Je9x90wT1hZJ2\nu/vSCq9D+IEGqzb8FQ/7zcwk/UTS4fHBzz4IvKBL0juTbRJAcar5tH+5pN9IeltjQ32StE5St6R2\njR32H5X03ezDwdRrsecHGizXw/68EH6g8XI77AcwNRF+ICjCDwRF+IGgCD8QFOEHgiL8QFCEHwiK\n8ANBEX4gKMIPBEX4gaAIPxAU4QeCavYU3R9I+r9xj7+UPdeKWrW3Vu1Lorda5dnbP1a7YFO/z/+5\njZsNuntHYQ0ktGpvrdqXRG+1Kqo3DvuBoAg/EFTR4e8vePsprdpbq/Yl0VutCumt0HN+AMUpes8P\noCCFhN/M7jSzI2b2rpk9WkQP5ZjZUTN728wOFj3FWDYN2gkze2fcc21m9msz+1N2O+E0aQX19piZ\nvZ+9dwfN7F8L6m2+mb1sZofN7A9m9kj2fKHvXaKvQt63ph/2m9k0SX+U1ClpSNLrkrrd/VBTGynD\nzI5K6nD3wseEzexfJJ2RtO3CbEhm9h+STrr7huwP51Xu/u8t0ttjmuTMzQ3qrdzM0t9Rge9dnjNe\n56GIPf8ySe+6+3vufk7SLyStLKCPlufu+yWdvOjplZK2Zve3auw/T9OV6a0luHvJ3Q9k909LujCz\ndKHvXaKvQhQR/nmS/jLu8ZBaa8pvl7TXzN4ws56im5nA7AszI2W3Vxfcz8UqztzcTBfNLN0y710t\nM17nrYjwTzSbSCsNOXzD3f9Z0l2S1mSHt6jOjyQt1tg0biVJG4tsJptZ+jlJ33P3U0X2Mt4EfRXy\nvhUR/iFJ88c9/rKkYwX0MSF3P5bdnpD0gsZOU1rJ8QuTpGa3Jwru52/c/bi7n3f3UUk/VoHvXTaz\n9HOSfubuz2dPF/7eTdRXUe9bEeF/XdI1ZvYVM5su6duSdhXQx+eY2YzsgxiZ2QxJ31TrzT68S9Kq\n7P4qSTsL7OXvtMrMzeVmllbB712rzXhdyEU+2VDGf0maJmmLu69vehMTMLNFGtvbS2PfePx5kb2Z\n2XZJt2rsW1/HJX1f0n9L+qWkBZL+LOlb7t70D97K9HarJjlzc4N6Kzez9IAKfO/ynPE6l364wg+I\niSv8gKAIPxAU4QeCIvxAUIQfCIrwA0ERfiAowg8E9f/Ex0YKZYOZcwAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "from matplotlib import pyplot\n",
    "import numpy as np\n",
    "\n",
    "pyplot.imshow(x_train[0].reshape((28, 28)), cmap=\"gray\")\n",
    "print(x_train.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "注意数据需转换成tensor才能参与后续建模训练\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[0., 0., 0.,  ..., 0., 0., 0.],\n",
      "        [0., 0., 0.,  ..., 0., 0., 0.],\n",
      "        [0., 0., 0.,  ..., 0., 0., 0.],\n",
      "        ...,\n",
      "        [0., 0., 0.,  ..., 0., 0., 0.],\n",
      "        [0., 0., 0.,  ..., 0., 0., 0.],\n",
      "        [0., 0., 0.,  ..., 0., 0., 0.]]) tensor([5, 0, 4,  ..., 8, 4, 8])\n",
      "torch.Size([50000, 784])\n",
      "tensor(0) tensor(9)\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "\n",
    "x_train, y_train, x_valid, y_valid = map(\n",
    "    torch.tensor, (x_train, y_train, x_valid, y_valid)\n",
    ")\n",
    "n, c = x_train.shape\n",
    "x_train, x_train.shape, y_train.min(), y_train.max()\n",
    "print(x_train, y_train)\n",
    "print(x_train.shape)\n",
    "print(y_train.min(), y_train.max())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### torch.nn.functional 很多层和函数在这里都会见到\n",
    "\n",
    "torch.nn.functional中有很多功能，后续会常用的。那什么时候使用nn.Module，什么时候使用nn.functional呢？一般情况下，如果模型有可学习的参数，最好用nn.Module，其他情况nn.functional相对更简单一些"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 53,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import torch.nn.functional as F\n",
    "\n",
    "loss_func = F.cross_entropy\n",
    "\n",
    "def model(xb):\n",
    "    return xb.mm(weights) + bias"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 54,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor(10.7988, grad_fn=<NllLossBackward>)\n"
     ]
    }
   ],
   "source": [
    "bs = 64\n",
    "xb = x_train[0:bs]  # a mini-batch from x\n",
    "yb = y_train[0:bs]\n",
    "weights = torch.randn([784, 10], dtype = torch.float,  requires_grad = True) \n",
    "bs = 64\n",
    "bias = torch.zeros(10, requires_grad=True)\n",
    "\n",
    "print(loss_func(model(xb), yb))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 创建一个model来更简化代码"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "- 必须继承nn.Module且在其构造函数中需调用nn.Module的构造函数\n",
    "- 无需写反向传播函数，nn.Module能够利用autograd自动实现反向传播\n",
    "- Module中的可学习参数可以通过named_parameters()或者parameters()返回迭代器"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 55,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "from torch import nn\n",
    "\n",
    "class Mnist_NN(nn.Module):\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.hidden1 = nn.Linear(784, 128)\n",
    "        self.hidden2 = nn.Linear(128, 256)\n",
    "        self.out  = nn.Linear(256, 10)\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = F.relu(self.hidden1(x))\n",
    "        x = F.relu(self.hidden2(x))\n",
    "        x = self.out(x)\n",
    "        return x\n",
    "        "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 56,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Mnist_NN(\n",
      "  (hidden1): Linear(in_features=784, out_features=128, bias=True)\n",
      "  (hidden2): Linear(in_features=128, out_features=256, bias=True)\n",
      "  (out): Linear(in_features=256, out_features=10, bias=True)\n",
      ")\n"
     ]
    }
   ],
   "source": [
    "net = Mnist_NN()\n",
    "print(net)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "可以打印我们定义好名字里的权重和偏置项"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 57,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "hidden1.weight Parameter containing:\n",
      "tensor([[ 0.0018,  0.0218,  0.0036,  ..., -0.0286, -0.0166,  0.0089],\n",
      "        [-0.0349,  0.0268,  0.0328,  ...,  0.0263,  0.0200, -0.0137],\n",
      "        [ 0.0061,  0.0060, -0.0351,  ...,  0.0130, -0.0085,  0.0073],\n",
      "        ...,\n",
      "        [-0.0231,  0.0195, -0.0205,  ..., -0.0207, -0.0103, -0.0223],\n",
      "        [-0.0299,  0.0305,  0.0098,  ...,  0.0184, -0.0247, -0.0207],\n",
      "        [-0.0306, -0.0252, -0.0341,  ...,  0.0136, -0.0285,  0.0057]],\n",
      "       requires_grad=True) torch.Size([128, 784])\n",
      "hidden1.bias Parameter containing:\n",
      "tensor([ 0.0072, -0.0269, -0.0320, -0.0162,  0.0102,  0.0189, -0.0118, -0.0063,\n",
      "        -0.0277,  0.0349,  0.0267, -0.0035,  0.0127, -0.0152, -0.0070,  0.0228,\n",
      "        -0.0029,  0.0049,  0.0072,  0.0002, -0.0356,  0.0097, -0.0003, -0.0223,\n",
      "        -0.0028, -0.0120, -0.0060, -0.0063,  0.0237,  0.0142,  0.0044, -0.0005,\n",
      "         0.0349, -0.0132,  0.0138, -0.0295, -0.0299,  0.0074,  0.0231,  0.0292,\n",
      "        -0.0178,  0.0046,  0.0043, -0.0195,  0.0175, -0.0069,  0.0228,  0.0169,\n",
      "         0.0339,  0.0245, -0.0326, -0.0260, -0.0029,  0.0028,  0.0322, -0.0209,\n",
      "        -0.0287,  0.0195,  0.0188,  0.0261,  0.0148, -0.0195, -0.0094, -0.0294,\n",
      "        -0.0209, -0.0142,  0.0131,  0.0273,  0.0017,  0.0219,  0.0187,  0.0161,\n",
      "         0.0203,  0.0332,  0.0225,  0.0154,  0.0169, -0.0346, -0.0114,  0.0277,\n",
      "         0.0292, -0.0164,  0.0001, -0.0299, -0.0076, -0.0128, -0.0076, -0.0080,\n",
      "        -0.0209, -0.0194, -0.0143,  0.0292, -0.0316, -0.0188, -0.0052,  0.0013,\n",
      "        -0.0247,  0.0352, -0.0253, -0.0306,  0.0035, -0.0253,  0.0167, -0.0260,\n",
      "        -0.0179, -0.0342,  0.0033, -0.0287, -0.0272,  0.0238,  0.0323,  0.0108,\n",
      "         0.0097,  0.0219,  0.0111,  0.0208, -0.0279,  0.0324, -0.0325, -0.0166,\n",
      "        -0.0010, -0.0007,  0.0298,  0.0329,  0.0012, -0.0073, -0.0010,  0.0057],\n",
      "       requires_grad=True) torch.Size([128])\n",
      "hidden2.weight Parameter containing:\n",
      "tensor([[-0.0383, -0.0649,  0.0665,  ..., -0.0312,  0.0394, -0.0801],\n",
      "        [-0.0189, -0.0342,  0.0431,  ..., -0.0321,  0.0072,  0.0367],\n",
      "        [ 0.0289,  0.0780,  0.0496,  ...,  0.0018, -0.0604, -0.0156],\n",
      "        ...,\n",
      "        [-0.0360,  0.0394, -0.0615,  ...,  0.0233, -0.0536, -0.0266],\n",
      "        [ 0.0416,  0.0082, -0.0345,  ...,  0.0808, -0.0308, -0.0403],\n",
      "        [-0.0477,  0.0136, -0.0408,  ...,  0.0180, -0.0316, -0.0782]],\n",
      "       requires_grad=True) torch.Size([256, 128])\n",
      "hidden2.bias Parameter containing:\n",
      "tensor([-0.0694, -0.0363, -0.0178,  0.0206, -0.0875, -0.0876, -0.0369, -0.0386,\n",
      "         0.0642, -0.0738, -0.0017, -0.0243, -0.0054,  0.0757, -0.0254,  0.0050,\n",
      "         0.0519, -0.0695,  0.0318, -0.0042, -0.0189, -0.0263, -0.0627, -0.0691,\n",
      "         0.0713, -0.0696, -0.0672,  0.0297,  0.0102,  0.0040,  0.0830,  0.0214,\n",
      "         0.0714,  0.0327, -0.0582, -0.0354,  0.0621,  0.0475,  0.0490,  0.0331,\n",
      "        -0.0111, -0.0469, -0.0695, -0.0062, -0.0432, -0.0132, -0.0856, -0.0219,\n",
      "        -0.0185, -0.0517,  0.0017, -0.0788, -0.0403,  0.0039,  0.0544, -0.0496,\n",
      "         0.0588, -0.0068,  0.0496,  0.0588, -0.0100,  0.0731,  0.0071, -0.0155,\n",
      "        -0.0872, -0.0504,  0.0499,  0.0628, -0.0057,  0.0530, -0.0518, -0.0049,\n",
      "         0.0767,  0.0743,  0.0748, -0.0438,  0.0235, -0.0809,  0.0140, -0.0374,\n",
      "         0.0615, -0.0177,  0.0061, -0.0013, -0.0138, -0.0750, -0.0550,  0.0732,\n",
      "         0.0050,  0.0778,  0.0415,  0.0487,  0.0522,  0.0867, -0.0255, -0.0264,\n",
      "         0.0829,  0.0599,  0.0194,  0.0831, -0.0562,  0.0487, -0.0411,  0.0237,\n",
      "         0.0347, -0.0194, -0.0560, -0.0562, -0.0076,  0.0459, -0.0477,  0.0345,\n",
      "        -0.0575, -0.0005,  0.0174,  0.0855, -0.0257, -0.0279, -0.0348, -0.0114,\n",
      "        -0.0823, -0.0075, -0.0524,  0.0331,  0.0387, -0.0575,  0.0068, -0.0590,\n",
      "        -0.0101, -0.0880, -0.0375,  0.0033, -0.0172, -0.0641, -0.0797,  0.0407,\n",
      "         0.0741, -0.0041, -0.0608,  0.0672, -0.0464, -0.0716, -0.0191, -0.0645,\n",
      "         0.0397,  0.0013,  0.0063,  0.0370,  0.0475, -0.0535,  0.0721, -0.0431,\n",
      "         0.0053, -0.0568, -0.0228, -0.0260, -0.0784, -0.0148,  0.0229, -0.0095,\n",
      "        -0.0040,  0.0025,  0.0781,  0.0140, -0.0561,  0.0384, -0.0011, -0.0366,\n",
      "         0.0345,  0.0015,  0.0294, -0.0734, -0.0852, -0.0015, -0.0747, -0.0100,\n",
      "         0.0801, -0.0739,  0.0611,  0.0536,  0.0298, -0.0097,  0.0017, -0.0398,\n",
      "         0.0076, -0.0759, -0.0293,  0.0344, -0.0463, -0.0270,  0.0447,  0.0814,\n",
      "        -0.0193, -0.0559,  0.0160,  0.0216, -0.0346,  0.0316,  0.0881, -0.0652,\n",
      "        -0.0169,  0.0117, -0.0107, -0.0754, -0.0231, -0.0291,  0.0210,  0.0427,\n",
      "         0.0418,  0.0040,  0.0762,  0.0645, -0.0368, -0.0229, -0.0569, -0.0881,\n",
      "        -0.0660,  0.0297,  0.0433, -0.0777,  0.0212, -0.0601,  0.0795, -0.0511,\n",
      "        -0.0634,  0.0720,  0.0016,  0.0693, -0.0547, -0.0652, -0.0480,  0.0759,\n",
      "         0.0194, -0.0328, -0.0211, -0.0025, -0.0055, -0.0157,  0.0817,  0.0030,\n",
      "         0.0310, -0.0735,  0.0160, -0.0368,  0.0528, -0.0675, -0.0083, -0.0427,\n",
      "        -0.0872,  0.0699,  0.0795, -0.0738, -0.0639,  0.0350,  0.0114,  0.0303],\n",
      "       requires_grad=True) torch.Size([256])\n",
      "out.weight Parameter containing:\n",
      "tensor([[ 0.0232, -0.0571,  0.0439,  ..., -0.0417, -0.0237,  0.0183],\n",
      "        [ 0.0210,  0.0607,  0.0277,  ..., -0.0015,  0.0571,  0.0502],\n",
      "        [ 0.0297, -0.0393,  0.0616,  ...,  0.0131, -0.0163, -0.0239],\n",
      "        ...,\n",
      "        [ 0.0416,  0.0309, -0.0441,  ..., -0.0493,  0.0284, -0.0230],\n",
      "        [ 0.0404, -0.0564,  0.0442,  ..., -0.0271, -0.0526, -0.0554],\n",
      "        [-0.0404, -0.0049, -0.0256,  ..., -0.0262, -0.0130,  0.0057]],\n",
      "       requires_grad=True) torch.Size([10, 256])\n",
      "out.bias Parameter containing:\n",
      "tensor([-0.0536,  0.0007,  0.0227, -0.0072, -0.0168, -0.0125, -0.0207, -0.0558,\n",
      "         0.0579, -0.0439], requires_grad=True) torch.Size([10])\n"
     ]
    }
   ],
   "source": [
    "for name, parameter in net.named_parameters():\n",
    "    print(name, parameter,parameter.size())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 使用TensorDataset和DataLoader来简化"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 58,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "from torch.utils.data import TensorDataset\n",
    "from torch.utils.data import DataLoader\n",
    "\n",
    "train_ds = TensorDataset(x_train, y_train)\n",
    "train_dl = DataLoader(train_ds, batch_size=bs, shuffle=True)\n",
    "\n",
    "valid_ds = TensorDataset(x_valid, y_valid)\n",
    "valid_dl = DataLoader(valid_ds, batch_size=bs * 2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 59,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def get_data(train_ds, valid_ds, bs):\n",
    "    return (\n",
    "        DataLoader(train_ds, batch_size=bs, shuffle=True),\n",
    "        DataLoader(valid_ds, batch_size=bs * 2),\n",
    "    )"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "- 一般在训练模型时加上model.train()，这样会正常使用Batch Normalization和 Dropout\n",
    "- 测试的时候一般选择model.eval()，这样就不会使用Batch Normalization和 Dropout"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 92,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def fit(steps, model, loss_func, opt, train_dl, valid_dl):\n",
    "    for step in range(steps):\n",
    "        model.train()\n",
    "        for xb, yb in train_dl:\n",
    "            loss_batch(model, loss_func, xb, yb, opt)\n",
    "\n",
    "        model.eval()\n",
    "        with torch.no_grad():\n",
    "            losses, nums = zip(\n",
    "                *[loss_batch(model, loss_func, xb, yb) for xb, yb in valid_dl]\n",
    "            )\n",
    "        val_loss = np.sum(np.multiply(losses, nums)) / np.sum(nums)\n",
    "        print('当前step:'+str(step), '验证集损失：'+str(val_loss))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 93,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "from torch import optim\n",
    "def get_model():\n",
    "    model = Mnist_NN()\n",
    "    return model, optim.SGD(model.parameters(), lr=0.001)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 94,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def loss_batch(model, loss_func, xb, yb, opt=None):\n",
    "    loss = loss_func(model(xb), yb)\n",
    "\n",
    "    if opt is not None:\n",
    "        loss.backward()\n",
    "        opt.step()\n",
    "        opt.zero_grad()\n",
    "\n",
    "    return loss.item(), len(xb)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 三行搞定！"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 95,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "当前step:0 验证集损失：2.2796445930480957\n",
      "当前step:1 验证集损失：2.2440698066711424\n",
      "当前step:2 验证集损失：2.1889826164245605\n",
      "当前step:3 验证集损失：2.0985311767578123\n",
      "当前step:4 验证集损失：1.9517273582458496\n",
      "当前step:5 验证集损失：1.7341805934906005\n",
      "当前step:6 验证集损失：1.4719875366210937\n",
      "当前step:7 验证集损失：1.2273896869659424\n",
      "当前step:8 验证集损失：1.0362271406173706\n",
      "当前step:9 验证集损失：0.8963696184158325\n",
      "当前step:10 验证集损失：0.7927186088562012\n",
      "当前step:11 验证集损失：0.7141492074012756\n",
      "当前step:12 验证集损失：0.6529350900650024\n",
      "当前step:13 验证集损失：0.60417300491333\n",
      "当前step:14 验证集损失：0.5643046331882476\n",
      "当前step:15 验证集损失：0.5317994566917419\n",
      "当前step:16 验证集损失：0.5047958114624024\n",
      "当前step:17 验证集损失：0.4813900615692139\n",
      "当前step:18 验证集损失：0.4618900228500366\n",
      "当前step:19 验证集损失：0.4443243554592133\n",
      "当前step:20 验证集损失：0.4297310716629028\n",
      "当前step:21 验证集损失：0.416976597738266\n",
      "当前step:22 验证集损失：0.406348459148407\n",
      "当前step:23 验证集损失：0.3963301926612854\n",
      "当前step:24 验证集损失：0.38733808159828187\n"
     ]
    }
   ],
   "source": [
    "train_dl, valid_dl = get_data(train_ds, valid_ds, bs)\n",
    "model, opt = get_model()\n",
    "fit(25, model, loss_func, opt, train_dl, valid_dl)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
